% Driver script for the demand estimation: start from an empty workspace and
% reset the random number generator so that every draw below is reproducible.
clear; 
rng('default');

% inputs are read from ../data/ and outputs are written to ../ (relative to pwd)
input_folder = [fullfile(pwd, '..', 'data') filesep];

output_folder = [fullfile(pwd, '..') filesep];

%% setup
use_gpu = false; % true: the objective is built from gpuArray inputs further below

years_in_sample = 2010 : 2016;

n_smln = 350; % # of simulation draws

% Open the log file in the current folder. fopen does not raise an error on
% failure (it returns -1), so check explicitly; otherwise the problem would
% only surface later at the first fprintf(pfid, ...).
[pfid, fopen_msg] = fopen('log_summary_stats_beer.txt', 'w+');
if pfid < 0
    error('cannot open log file log_summary_stats_beer.txt: %s', fopen_msg);
end
clear fopen_msg

% number of starting values for estimating demand
n_trial = 20;

dim_price = 1;

% settings collected in "setup" are handed to the objective function later on
setup.tol_berry_inversion = 1e-5;

setup.n_iter = 100;


%% load data and construct the estimation sample
n_trips = 8; % number of purchase opportunities per month

p_cutoff = 30;% remove very high priced items

% Load the yearly files and define market size. Each year's brand-level table
% and its market-size vector are kept in a cell and stacked once after the loop.
n_sample_years = numel(years_in_sample);
agg_by_year = cell(n_sample_years, 1);
mktsize_by_year = cell(n_sample_years, 1);

k = 0;
for yy = years_in_sample

    k = k + 1;
    fprintf(pfid, 'construct data, year %1.0f\n', yy);

    yy_str = num2str(yy);
    filename = [input_folder 'beerdata_step3_all_var_' yy_str '.mat'];

    % brand / retailer / county / month table for this year
    load([input_folder 'intermediatefile_sample_agg_brand_retailer_county_year_month_', yy_str '.mat'], 'sample_beerdata_agg_1year');
    agg_by_year{k} = sample_beerdata_agg_1year;

    % store-level quantities that define the market size
    mktsize_file = struct2table(load([input_folder 'all_alcohol_sales_' yy_str '.mat'], 'quantity', 'store_code', 'yearmonth'));

    % mapping from store to its retailer and county (one row per store_code)
    mapping = struct2table(load(filename, 'retail_id','store_county_id', 'store_code'));
    [~, tmpind] = unique(mapping.store_code);
    mapping = mapping(tmpind, :);
    % retailer and county are packed into a single numeric key
    mapping.retail_cty_id = mapping.retail_id*1e7 + mapping.store_county_id;
    mapping = removevars(mapping, {'retail_id', 'store_county_id'});

    % keep only stores that appear in the mapping, then attach the key
    mktsize_file = mktsize_file(ismember(mktsize_file.store_code, mapping.store_code), :);
    mktsize_file = join(mktsize_file, mapping, 'Keys', 'store_code');

    % compute the total quantity for each mkt/month
    mktsize_file = varfun(@sum, mktsize_file, 'InputVariables', 'quantity', 'GroupingVariables',{'retail_cty_id','yearmonth'});

    % merge with sample_beerdata_agg_1year (join keeps the row order of the left table)
    sample_beerdata_agg_1year.retail_cty_id = sample_beerdata_agg_1year.retail_id*1e7 + sample_beerdata_agg_1year.store_county_id;
    sample_beerdata_agg_1year = join(sample_beerdata_agg_1year, mktsize_file, 'Keys', {'retail_cty_id', 'yearmonth'});

    % total quantity rescaled by 144 and multiplied by the number of trips
    % NOTE(review): 144 is presumably a unit conversion - confirm
    mktsize_yy = sample_beerdata_agg_1year.sum_quantity/144*n_trips;
    mktsize_by_year{k} = mktsize_yy;
end

% stack the yearly pieces in the order they were read
beerdata_agg = vertcat(agg_by_year{:});
mktsize = vertcat(mktsize_by_year{:});

if any(isnan(mktsize))
    error('nan in mktsize_yy');
end
beerdata_agg.mktsize = mktsize;
clear sample_beerdata_agg_1year tmpind mktsize_yy mktsize_file mktsize ind mapping
clear n_sample_years agg_by_year mktsize_by_year yy_str

% remove very high price products
beerdata_agg = beerdata_agg(beerdata_agg.price<p_cutoff, :);

% define mkt id: one id per (county, retailer, year-month); rows are then
% sorted so that each market occupies a contiguous block
[~, ~, beerdata_agg.mktid] = unique(beerdata_agg(:, {'store_county_id', 'retail_id', 'yearmonth'}), 'rows');
beerdata_agg = sortrows(beerdata_agg, 'mktid');    

% unpack the table: every table variable becomes a workspace variable of the
% same name (eval is needed because the names are only known at run time)
varlist = beerdata_agg.Properties.VariableNames;
for i = 1:length(varlist)
    eval([varlist{i} '= beerdata_agg.' varlist{i} ';']);
end
clear beerdata_agg;

% new variables to be defined
% varlist is the list of per-row variables that the sample filters further
% below subset row by row, so every per-row variable created before those
% filters has to be listed here.
varlist = [varlist, 'mktsh', 'dummy_lager', 'dummy_light', 'dummy_ale', 'mo', 'yr', 'brand_descr'];

% craftdummy is built per row before the sample filters and is later combined
% column-wise with the filtered variables, so it must be filtered with them.
% The guard avoids listing (and hence filtering) it twice in case the loaded
% table already carries a variable of that name.
if ~ismember('craftdummy', varlist)
    varlist = [varlist, 'craftdummy'];
end

% mkt share
mktsh = quantity ./ mktsize;

% adjust prices for inflation
% yearmonth is split into a month part (last two digits) and a year part
mo = mod(yearmonth, 100);
yr = (yearmonth - mo)/100;

% column 1: year, column 2: multiplier applied to the prices of that year
inflation_tbl = [2010, 1.1;
                2011, 1.07;
                2012, 1.05;
                2013, 1.03;
                2014, 1.01;
                2015, 1.01;
                2016, 1];

% look up the row of each observation's year and rescale its price;
% observations whose year is not listed are left unchanged
[yr_listed, yr_row] = ismember(yr, inflation_tbl(:, 1));
price(yr_listed) = price(yr_listed).*inflation_tbl(yr_row(yr_listed), 2);
clear yr_listed yr_row

% product module dummies (one indicator per product_module_code value)
dummy_lager = (product_module_code == 5000);
dummy_light = (product_module_code == 5010);
dummy_ale = (product_module_code == 5015);

% define craft as ever appearing in the list of craft beer association
% brand-level information (description, owner, brewer) keyed by brand_id
brand_vars = {'brand_descr','brand_id', 'owner', 'owner_id', 'st_craftbrewer', 'brewer', 'brewer_id'};
brandfile = struct2table(load([input_folder 'beerdata_step2_brewer_owner_08to14.mat'], brand_vars{:}));
clear brand_vars

% every brand in the sample must also be present in the brand file
[brands_in_sample, ~, brand_pos] = unique(brand_id);
[brands_matched, ~, ib] = intersect(brand_id, brandfile.brand_id);
assert(isequal(brands_in_sample, brands_matched));
clear brands_in_sample brands_matched

% description of each distinct brand, expanded back to one entry per row
tmp = brandfile.brand_descr(ib);
brand_descr = tmp(brand_pos);
clear brand_pos

% a row is flagged as craft when its brand description appears in the csv list
craft_ds = datastore([input_folder, 'craftbrand.csv']);
craft_cat = readall(craft_ds);
craftdummy = ismember(brand_descr, craft_cat{:, :});
clear craft_ds craft_cat


% construct sample
% s0 = 1 - (sum of mktsh within each mktid), expanded back to one value per row.
% Indexing the grouped sums with mktid assumes mktid takes the values 1..max
% with no gap (it is the third output of unique at this point).
s0 = grpstats(mktsh, mktid, 'sum');
s0 = s0(mktid);
s0 = 1 - s0;

choose_mkt_all = load([input_folder, 'choose_mkt_9_2009_2016']); % load data selection criteria from choose_mkt_CA_2009_2016

% ind_cty accumulates, year by year, the rows that are kept
ind_cty = false(size(brand_id));
% pack (county, retailer, brand) into one integer key so that a whole triple
% can be matched with a single ismember call
% NOTE(review): the key is only unique if brand_id < 1e6 and retail_id < 1e4 - confirm
mix_ind = int64(store_county_id*(10^10)) + int64(retail_id*10^6) + int64(brand_id);
for ny = 1 : length(years_in_sample)    
    % the selection cell is indexed by (year - 2006): column 1 holds the
    % triples to keep, column 2 a quantity cutoff
    % NOTE(review): column order (county, retailer, brand) is inferred from how the key is built below - confirm
    geomkt_brand_for_est = choose_mkt_all.choose_mkt{years_in_sample(ny)-2006, 1};    
    q_cutoff_select = choose_mkt_all.choose_mkt{years_in_sample(ny)-2006, 2};
    
    mix_ind_ny = int64(geomkt_brand_for_est(:, 1)*10^10) + int64(geomkt_brand_for_est(:, 2)*10^6) + int64(geomkt_brand_for_est(:, 3));
    
    % rows of this year whose triple is selected and whose quantity exceeds the cutoff
    ind_cty_j = ismember(mix_ind, mix_ind_ny) & quantity>q_cutoff_select...
            & yearmonth>=years_in_sample(ny)*100+1 & yearmonth<=years_in_sample(ny)*100+12;    
    
    % trim: keep rows whose s0 lies strictly between the 1% and 99% quantiles
    % of s0 computed over the rows selected so far for this year
    ind_cty_j = ind_cty_j & s0<quantile(s0(ind_cty_j), 0.99) & s0>quantile(s0(ind_cty_j), 0.01);
    
    ind_cty = ind_cty | ind_cty_j;    
end
clear mix_ind mix_ind_ny ind_cty_j

% apply the row filter to every variable named in varlist; any per-row array
% that is not named there keeps its previous length
for i = 1:length(varlist)
    eval([varlist{i} '= ' varlist{i} '(ind_cty, :);']);
end

% a market must be present in every month of the year
% yrgeomkt labels each (year, county, retailer) combination with 1..G
[~, ~, yrgeomkt] = unique([yr, store_county_id, retail_id], 'rows');
yrgeomkt_list = unique(yrgeomkt);

% number of distinct months observed for each year-market: list the distinct
% (year-market, month) pairs once and count the pairs per year-market
mkt_month_pairs = unique([yrgeomkt, mo], 'rows');
nm_yrgeomkt = accumarray(mkt_month_pairs(:, 1), 1);

% keep the rows of year-markets observed in all 12 months
ind_cty = ismember(yrgeomkt, find(nm_yrgeomkt==12));

for i = 1:numel(varlist)
    eval(sprintf('%s= %s(ind_cty, :);', varlist{i}, varlist{i}));
end
clear ind_cty

% market ids are rebuilt on the reduced sample so that they are consecutive again
[~, ~, mktid] = unique([store_county_id, retail_id, yearmonth], 'rows');

% s0: one minus the summed shares of the row's market
mkt_inside_share = grpstats(mktsh, mktid, 'sum');
s0 = 1 - mkt_inside_share(mktid);
clear mkt_inside_share

clear varlist nm_yrgeomkt yrgeomkt yrgeomkt_list mkt_month_pairs

%% define variables in estimation
% 1. instruments
% load monthly price data
barley = readall(datastore([input_folder, 'barley_price.csv']));
ym_list = unique(yearmonth);     

% calendar of all year-months 2009-01 .. 2016-12 in chronological order
ym_list_all = ((2009:2016)'*100 + (1:12))';
ym_list_all = ym_list_all(:);

% barley price of every calendar month; the year/month parts of the dates are
% extracted once instead of once per iteration. The assignment fails if a
% calendar month has no, or more than one, matching record.
barley_yr = year(barley.Date);
barley_mo = month(barley.Date);
iv = nan(size(ym_list_all, 1), 1);
for k = 1 : length(ym_list_all)    
    mk = mod(ym_list_all(k), 100);    
    yk = floor((ym_list_all(k)-mk)/100);    
    iv(k) = barley.Value(barley_yr==yk & barley_mo==mk);    
end
clear barley_yr barley_mo

% position of the first sample month in the calendar (kept in the workspace
% as before)
start_ym = find(ym_list_all == ym_list(1));

% Map every observation to the price of its own year-month. Looking each month
% up directly, rather than taking a window of length(ym_list) consecutive
% calendar months starting at start_ym, stays correct when the sample months
% are not contiguous (e.g. a month without any market); with contiguous months
% both approaches give the same result.
[ym_found, ym_pos] = ismember(yearmonth, ym_list_all);
if ~all(ym_found)
    error('yearmonth contains months outside the 2009-2016 barley price calendar');
end
iv = iv(ym_pos);
clear ym_found ym_pos

% 2. model specific categorical variables interacted with random variables
% 2.1. load microdata, calculate the weights for each market-year
n_micro_years = length(years_in_sample);

% sample specific identifiers (will be added to the household code so that we
% do not track households across years)
code_sample_id = rand(n_micro_years, 1);

% yearly pieces are collected in cells and stacked after the loop
micro_by_year = cell(n_micro_years, 1);
matched_by_year = cell(n_micro_years, 1);

for j = 1 : n_micro_years
    load([input_folder, 'beerdata_micro_', num2str(years_in_sample(j))], 'microdata');
    micro_by_year{j} = microdata;

    % select households matching the county-retailer pairs in this year's sample
    sample_ind = yr==years_in_sample(j);
    county_retail_in_sample = unique([store_county_id(sample_ind), retail_id(sample_ind)], 'rows');
    ind_micro_cr = ismember([microdata.hhl_county microdata.retailer_code], county_retail_in_sample, 'rows');
    microdata_matched = microdata(ind_micro_cr, :);

    % modify the household code to make it sample specific
    microdata_matched.household_code = microdata_matched.household_code + code_sample_id(j);

    matched_by_year{j} = microdata_matched;
end

% all households / households matched to the estimation sample
microdata_all = vertcat(micro_by_year{:});
microdata2 = vertcat(matched_by_year{:});

clear code_sample_id ind_micro_cr
clear n_micro_years micro_by_year matched_by_year

% 2.2 get coordinates to calculate brand-market distance and identify foreign brands
% load coordinates of breweries and give the csv columns meaningful names
brewer_loc_tbl = readall(datastore([input_folder, 'brewer_loc.csv']));
csv_names = ["Var1","Var2","Var3","Var4","Var5","Var6","Var7"];
new_names = ["brewer","Var2","Var3","latitude_brewer","longtitude_brewer","import","foreign_flavor"];
brewer_loc_tbl = renamevars(brewer_loc_tbl, csv_names, new_names);
clear csv_names new_names

% match brewers to brewer_id (one row per brewer name from the brand file)
[~,ind] = unique(brandfile.brewer);
tmp = brandfile(ind,{'brewer', 'brewer_id'});
brewer_loc_tbl = join(brewer_loc_tbl, tmp, 'RightVariables', 'brewer_id', 'Keys', 'brewer');

% construct a table with brewer_id and county information (one row per observation)
brewer_cty_tbl = table(brewer_id, store_county_id);

% add county coordinates to brewer_cty_tbl; the county key is state code*1000 + county code
geo_tbl = readall(datastore([input_folder, 'CenPop2010_Mean_CO.csv']));
geo_tbl.store_county_id = geo_tbl.COUNTYFP + geo_tbl.STATEFP*1000;
geo_tbl = renamevars(geo_tbl, ["LATITUDE","LONGITUDE"], ["latitude_cty","longtitude_cty"]);

brewer_cty_tbl = join(brewer_cty_tbl, geo_tbl, 'Keys', 'store_county_id');

% use the same brewer_id for the merged brewers
is_merged_brewer = ismember(brewer_loc_tbl.brewer, {'miller brewing co', 'coors brewing co'});
merge_id = unique(brewer_loc_tbl.brewer_id(is_merged_brewer));
brewer_loc_tbl.brewer_id(ismember(brewer_loc_tbl.brewer_id, merge_id)) = merge_id(1);
brewer_cty_tbl.brewer_id(ismember(brewer_cty_tbl.brewer_id, merge_id)) = merge_id(1);
clear is_merged_brewer

% add brewer coordinates to brewer_cty_tbl
% The nearest-neighbour search runs on (brewer_id*1e8, latitude, longitude):
% the scaled id dominates the distance, so a location with the same brewer_id
% is preferred and, among those, the closest one in coordinate space is chosen.
brewery_pts = [brewer_loc_tbl.brewer_id*1e8, brewer_loc_tbl.latitude_brewer, brewer_loc_tbl.longtitude_brewer];
county_pts = [brewer_cty_tbl.brewer_id*1e8, brewer_cty_tbl.latitude_cty, brewer_cty_tbl.longtitude_cty];
idx = knnsearch(brewery_pts, county_pts);
brewer_cty_tbl.latitude_brewer = brewer_loc_tbl.latitude_brewer(idx);
brewer_cty_tbl.longtitude_brewer = brewer_loc_tbl.longtitude_brewer(idx);
clear idx brewery_pts county_pts

% compute distance between a county and a brewer's location
county_coord = [brewer_cty_tbl.latitude_cty brewer_cty_tbl.longtitude_cty];
brewery_coord = [brewer_cty_tbl.latitude_brewer brewer_cty_tbl.longtitude_brewer];
cty_br_dist = lldistkm_vec(county_coord, brewery_coord);
clear county_coord brewery_coord

% foreign flavor: rows whose brewer is flagged foreign_flavor == 1
ff_brewer_id = brewer_loc_tbl.brewer_id(brewer_loc_tbl.foreign_flavor == 1);
dummy_ff = ismember(brewer_cty_tbl.brewer_id, ff_brewer_id);
clear ff_brewer_id
% foreign flavored brands (to be used for identifying foreign beers in the micro data)
brand_ff = unique(brand_id(dummy_ff));

clear brewer_cty_tbl;

% covariates for estimation
% columns: lager, light, craft, foreign flavor, ale
x_rand = [dummy_lager, dummy_light, craftdummy, dummy_ff, dummy_ale];
x_rand2 = craftdummy;

% bounds of the nonlinear parameters
% normal coefficients on dim_rand and intercept
% income on price and intercept
% income on dim_rand
para_ub = [  2   2.8 2.8 2.2  3  10  10 5.00  -2 35];
para_lb = [-20    -9 0.2  -8 -8 -10 -10 0.00 -35 -3];

% simulation draws: scrambled Halton points (one dimension per column of
% x_rand plus one) mapped through the normal inverse cdf, one draw per column
halton_pts = haltonset(size(x_rand, 2)+1,'Skip',1e3,'Leap',1e2);
halton_pts = scramble(halton_pts, 'RR2');
randomdraw = norminv(net(halton_pts, n_smln)');
clear halton_pts
% rescale so that the sample covariance of the draws is the identity
randomdraw = chol(inv(cov(randomdraw')))*randomdraw;

% row 1: lower bounds, row 2: upper bounds
ga_init_range = [para_lb; para_ub];
startpts = [0.1029    1.0663    0.85    0.9248    0.8277   -8.0241   -0.8092    1.0124  -15.9037    7.7479];

% the same five indicators for the micro data, in the column order of x_rand
micro_module = microdata2.product_module_code;
dummy_data = [micro_module==5000, ...%lager
    micro_module==5010, ...%light
    microdata2.craftdummy, ...
    ismember(microdata2.brand_code_uc_data_prod, brand_ff), ...
    micro_module==5015];%ale
clear micro_module

% geographic market: one id per (county, retailer) pair
[~, ~, geomkt] = unique([store_county_id, retail_id], 'rows');

% verify whether each retailer-county shows up in every month in a year
for yy = years_in_sample
    in_year = (yr == yy);
    geomkt_yy = unique(geomkt(in_year));
    for gg = 1 : length(geomkt_yy)
        % distinct months observed for this market in this year
        months_present = unique(mo(geomkt==geomkt_yy(gg) & in_year));
        if ~isequal(1:12, months_present')
            error('a market is not present for the full year');
        end
    end
end
clear geomkt_yy in_year months_present

%% estimation
% the market ids have to be exactly 1, 2, ..., max(mktid)
mktid_consecutive = isequal(unique(mktid), (1:max(mktid))');
if ~mktid_consecutive
    error('mktid goes from 1 to max(mktid) with no gap');
end
clear mktid_consecutive

% log share, and log share minus log of s0, stored for the objective
logshare = log(mktsh);
setup.meanval_logit = logshare - log(s0);
setup.mktid = mktid;
setup.pfid = pfid;

% last and first row of every run of equal mktid values
ind_mkt_end = [find(diff(mktid)); length(mktid)];
ind_mkt_start = [1; ind_mkt_end(1:end-1) + 1];

% take random draws for demographic variables
% For every geographic market a beta distribution (scaled to [0, max_beta]) is
% fitted to the weighted empirical income cdf at the points in beta_cutoff,
% and n_smln incomes are then drawn from the fitted distribution.
beta_cutoff = [0.5 1 1.5 2 2.5, 3, 5, 7.5, 10]*1e4;
max_beta = 30*1e4;

n_geo = max(geomkt);
rand_income_geo = nan(n_geo, n_smln);
for ng = 1 : n_geo
    store_county_id_ng = unique(store_county_id(geomkt==ng));
    retail_id_ng = unique(retail_id(geomkt==ng));
    ind_ng = microdata_all.hhl_county == store_county_id_ng & microdata_all.retailer_code==retail_id_ng;
    
    % with fewer than 20 matching rows, fall back to all rows of the county
    if nnz(ind_ng)<20
        ind_ng = microdata_all.hhl_county == store_county_id_ng;
    end
    
    household_weight_ng = microdata_all.projection_factor(ind_ng);
    household_income_ng = microdata_all.household_income(ind_ng);
        
    % calibrate a beta distribution based on CDF
    % Weighted share of households below each cutoff (1 x numel(beta_cutoff)).
    % The sum runs explicitly over households (dimension 1): without the
    % dimension argument a single selected household would make sum() add up
    % across the cutoffs instead. The result is kept in its own variable so
    % that the price cutoff stored in p_cutoff is not overwritten.
    inc_cdf_ng = ...
        sum((household_income_ng<beta_cutoff).*household_weight_ng, 1)./sum(household_weight_ng);
    
    % squared distance between the beta cdf and the empirical cdf at the cutoffs
    p_beta = @(pb) sum((betacdf(beta_cutoff./max_beta, pb(1), pb(2)) - inc_cdf_ng).^2);
    
    p_beta_est = fminsearch(p_beta, [2, 5]);
    
    % simulate household income based on the beta distribution
    rand_income_geo(ng, :) = betainv(rand(1, n_smln), p_beta_est(1), p_beta_est(2))*max_beta;
end

% expand the market-level draws to one row per observation
rand_income = rand_income_geo(geomkt, :); clear rand_income_geo household_weight_ng household_income_ng store_county_id_ng retail_id_ng ind_ng inc_cdf_ng
 
%% micro moments
% track household-market level behavior
[~, ~, household_mkt] = unique([microdata2.household_code microdata2.hhl_county microdata2.retailer_code], 'rows');

% income and craft purchases
cutoff_craft_inc = [5 10]*1e4;
cutoff = [5 10]*1e4;

% arguments shared by every call of gen_micro_moment; only the last argument
% (empty for the data moments, resampled ids for the bootstrap) differs
mm_args = {household_mkt, ...
    microdata2.quantity, dummy_data, microdata2.price, ...
    microdata2.household_income, microdata2.projection_factor, cutoff, ...
    cutoff_craft_inc};

[micro_m_data, maxid, mi_craft_ind] = gen_micro_moment(mm_args{:}, []);

% use bootstrap for variance: each replication draws maxid ids with replacement
n_boot = 500;
n_moments = length(micro_m_data);
bootm = nan(n_moments, n_boot);
for n = 1 : n_boot
    id_boot = randi(maxid, [maxid, 1]);
    m_n = gen_micro_moment(mm_args{:}, id_boot);
    bootm(:, n) = m_n;
end
clear mm_args

% diagonal weighting matrix taken from the pseudo-inverse of the bootstrap covariance
W2 = diag(diag(pinv(cov(bootm'))));

% the weights of moments 8, 7, 10 and 6 are set by hand
for fixed_m = [8 7 10 6]
    W2(fixed_m, fixed_m) = 1e3;
end
clear fixed_m



% Build the linear regressors x and the instruments iv_x, then remove brand
% and geographic-market fixed effects from both by alternating projections.

% consecutive brand index (1..number of brands) and month indicators
[~, ~, brand_id2] = unique(brand_id);            
mo_dummy = dummyvar(mo);

% dummies for distances
% those between [0 50KM] are the baseline
dist_x = dist_dummy(cty_br_dist);  clear cty_br_dist
% distance dummies (all columns but the first) interacted with the craft flag
dist_craft = dist_x(:, 2:end).*craftdummy;  

% x: price, distance dummies, their craft interactions, and the first 11 month dummies.
% iv_x: same layout, with price replaced by iv multiplied into each column of x_rand.
x = [price, dist_x, dist_craft, mo_dummy(:, 1:11)];    
iv_x = [iv.*x_rand dist_x dist_craft mo_dummy(:, 1:11)];

% alternating projection    
% mean1 averages along dimension 1 explicitly, so a group consisting of a
% single row is not averaged across its columns instead
mean1 = @(x) mean(x, 1);
% first pass: subtract brand means
mean_x_brand = grpstats(x, brand_id2, mean1);    
mean_iv_brand = grpstats(iv_x, brand_id2, mean1);

demeaned_x = x - mean_x_brand(brand_id2, :);    
demeaned_iv = iv_x - mean_iv_brand(brand_id2, :);

mktfe_id = geomkt;

% Alternate between subtracting geographic-market means and brand means until
% the largest absolute change of any entry over one full sweep is at most 1e-3.
% NOTE(review): there is no cap on the number of sweeps - confirm convergence is guaranteed for this design
diffxiv = 10;    
while diffxiv>1e-3        
    dx_prev = demeaned_x;        
    div_prev = demeaned_iv;
    
    mean_x_geo = grpstats(demeaned_x, mktfe_id, mean1);    
    mean_iv_geo = grpstats(demeaned_iv, mktfe_id, mean1);
    
    demeaned_x = demeaned_x - mean_x_geo(mktfe_id, :);    
    demeaned_iv = demeaned_iv - mean_iv_geo(mktfe_id, :);
    
    mean_x_brand = grpstats(demeaned_x, brand_id2, mean1);        
    mean_iv_brand = grpstats(demeaned_iv, brand_id2, mean1);
    
    demeaned_x = demeaned_x - mean_x_brand(brand_id2, :);        
    demeaned_iv = demeaned_iv - mean_iv_brand(brand_id2, :);
    
    diffxiv = max(abs([demeaned_x(:) - dx_prev(:); demeaned_iv(:) - div_prev(:)]));        
end
clear mean_x_brand mean_iv_brand mean_x_geo mean_iv_geo mktfe_id

% weighting matrix: inverse covariance of the demeaned instruments
W1 = inv(cov(demeaned_iv));

% indices
% rows sorted by brand, number of obs for each brand, and the last sorted
% position of each brand
[iS1_brand, iS2_brand] = sort(brand_id2);
n_obs_per_brand = accumarray(iS1_brand, 1);
ind_obs_brand = cumsum(n_obs_per_brand);

% the same three objects for the geographic markets
[iS1_geomkt, iS2_geomkt] = sort(geomkt);
n_obs_per_geomkt = accumarray(iS1_geomkt, 1);
ind_obs_geomkt = cumsum(n_obs_per_geomkt);


% how long a household is tracked
T = 12;

% verify if data are sorted by geomkt and then by time
for j = 1 : max(geomkt)
    ym_in_mkt = yearmonth(geomkt==j);
    if any(diff(ym_in_mkt)<0)
        error('data should be sorted by time within a geo market');
    end
end
clear ym_in_mkt

% market size taken from the first row of every (geomkt, year) combination
[~, ind] = unique([geomkt, yr], 'rows');
mktsize_geomktyr = mktsize(ind);

% drop loop counters, temporaries and the large raw inputs so that they are
% not written to the .mat file below (names that do not exist are ignored)
clear('ans', 'channel_id', 'choose_mkt_all', 'gg', 'host', 'i', 'ib', 'id_boot', 'ind', 'inflation_tbl');
clear('x', 'iv', 'iv_x', 'j', 'k', 'm_n', 'max_beta', 'mean1');
clear('microdata_all', 'microdata2', 'microdata', 'mk', 'n', 'n_boot', 'n_geo', 'n_moments', 'ng', 'ny', 'tmp', 'yk', 'yy');

% simulated incomes enter in logs; their mean over the draws is kept per row
loginc = log(rand_income);
inc_mkt = mean(rand_income, 2);
clear rand_income;

logincnorm = 10;
logcutoff = log(cutoff);
logcutoff_craft = log(cutoff_craft_inc);

% store the whole remaining workspace
save([output_folder, 'est_demand_variables_for_est'],'-v7.3');

% Build the GMM objective as a function of the nonlinear parameters only.
% Both variants receive the same argument list; on the gpu path the numeric
% inputs are first moved to the device.
if use_gpu
    % data
    logshare = gpuArray(logshare);
    price = gpuArray(price);
    x_rand = gpuArray(x_rand);
    x_rand2 = gpuArray(x_rand2);
    micro_m_data = gpuArray(micro_m_data);
    demeaned_x = gpuArray(demeaned_x);
    demeaned_iv = gpuArray(demeaned_iv);
    mktsize_geomktyr = gpuArray(mktsize_geomktyr);

    % weighting matrices
    W1 = gpuArray(W1);
    W2 = gpuArray(W2);

    % simulation draws and income terms
    randomdraw = gpuArray(randomdraw);
    loginc = gpuArray(loginc);
    logincnorm = gpuArray(logincnorm);
    logcutoff = gpuArray(logcutoff);
    logcutoff_craft = gpuArray(logcutoff_craft);

    % group sizes
    n_obs_per_brand = gpuArray(n_obs_per_brand);
    n_obs_per_geomkt = gpuArray(n_obs_per_geomkt);

    % fields of the setup struct
    setup.meanval_logit = gpuArray(setup.meanval_logit);
    setup.mktid = gpuArray(setup.mktid);

    gmm_fun = @gmmobj_demand_beer_gpu;
else
    gmm_fun = @gmmobj_demand_beer_cpu;
end

% the handle captures the current values of all listed variables
obj = @(para_nonlinear) gmm_fun(para_nonlinear, [], ...
    logshare, price,  micro_m_data,  ...
    W1,  W2,  mktid, x_rand,  ...
    randomdraw,  loginc,  logincnorm,  ...
    logcutoff,  logcutoff_craft,  ind_mkt_start, ind_mkt_end, n_trips, ...
    demeaned_x,  demeaned_iv, brand_id2, geomkt, setup, ...
    T, iS2_brand, n_obs_per_brand,  ind_obs_brand, ...
    iS2_geomkt, n_obs_per_geomkt,  ind_obs_geomkt, ...
    mktsize_geomktyr,  false, true, x_rand2);
clear gmm_fun

% The optimizer evaluates the objective in parallel only on the cpu path.
if use_gpu
    use_parallel = false;
else
    use_parallel = true;
    % parpool raises an error when a pool is already running (for example when
    % the script is rerun in the same session; "clear" does not close the
    % pool), so start one only if none exists and reuse the current one otherwise.
    if isempty(gcp('nocreate'))
        parpool('local', maxNumCompThreads);
    end
end



% Run the surrogate optimizer from n_trial starting points; one row of
% para_trial / one entry of fval_trial per trial.
para_trial = nan(n_trial, size(ga_init_range, 2));
fval_trial = nan(n_trial, 1);

% options shared by all trials; the starting point is filled in per trial
base_opts = optimoptions(@surrogateopt, 'Display','final',...
        'MaxFunctionEvaluations', 1000, ...
        'UseParallel', use_parallel);

for nt = 1 : n_trial
    % the first trial starts at startpts, later trials at a random
    % perturbation of it (normal noise with standard deviation 0.1)
    startpts_nt = startpts;
    if nt > 1
        startpts_nt = startpts+randn(size(startpts))/10;
    end

    opts_sur = optimoptions(base_opts, 'InitialPoints', startpts_nt);

    [para_nt, fval_nt] = surrogateopt(obj, para_lb, para_ub, opts_sur);

    para_trial(nt, :) = para_nt;
    fval_trial(nt) = fval_nt;

    % write the results obtained so far after every trial
    save([output_folder, 'est_demand.mat'], 'para_trial', 'fval_trial');
end
clear base_opts

% Pick the trial with the smallest objective value. The selection is done once
% and the gathered (host-memory) parameter vector is used both for the saved
% file and for the report below; min_ind, minf and minf_ind remain available
% in the workspace.
[minf, minf_ind] = min(fval_trial);
min_ind = minf_ind;
para_nonlinear = gather(para_trial(minf_ind, :));
save([output_folder, 'est_demand.mat'], 'para_trial', 'fval_trial','para_nonlinear');

% report the first size(x_rand, 2)+1 parameters after exponentiation, the
% remaining ones as estimated
n_exp_para = size(x_rand, 2)+1;
para_nonlinear_transformed = para_nonlinear;
para_nonlinear_transformed(1:n_exp_para) = exp(para_nonlinear(1:n_exp_para));
clear n_exp_para

% the second output of the objective holds the linear parameters
[~, para_linear] = obj(para_nonlinear);

% print to standard output; gather leaves ordinary arrays unchanged and brings
% gpu results back to host memory before formatting
fprintf(1, 'para_nonlinear_transformed ='); fprintf(1, ' %1.6e', para_nonlinear_transformed); fprintf(1, '\n');
fprintf(1, 'para_linear ='); fprintf(1, ' %1.6e', gather(para_linear)); fprintf(1, '\n');
